/*
 * Wazuh Vulnerability scanner
 * Copyright (C) 2015, Wazuh Inc.
 * March 25, 2023.
 *
 * This program is free software; you can redistribute it
 * and/or modify it under the terms of the GNU General Public
 * License (version 2) as published by the FSF - Free Software
 * Foundation.
 */

#include "argsParser.hpp"
#include "contentManager.hpp"
#include "flatbuffers/idl.h"
#include "flatbuffers/include/syscollector_deltas_generated.h"
#include "flatbuffers/include/syscollector_deltas_schema.h"
#include "flatbuffers/include/syscollector_synchronization_generated.h"
#include "flatbuffers/include/syscollector_synchronization_schema.h"
#include "routerModule.hpp"
#include "routerProvider.hpp"
#include "socketClient.hpp"
#include "socketServer.hpp"
#include "stringHelper.h"
#include "vulnerabilityScanner.hpp"
#include <atomic>
#include <cerrno>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <sys/socket.h>
#include <sys/un.h>
#include <system_error>
#include <thread>
#include <unistd.h>
#include <utility>
#include <vector>

/// Capacity of the fixed buffers used for received datagrams and formatted log messages.
constexpr auto MAXLEN = 65536;
/// UNIX datagram socket path where the fake report (alerts) server listens.
constexpr auto DEFAULT_QUEUE_PATH = "queue/sockets/queue";
/// UNIX socket path served by the fake Wazuh-DB server.
constexpr auto DEFAULT_WDB_SOCKET = "queue/db/wdb";
/// Directory created (when missing) to hold the fake report server socket.
constexpr auto DEFAULT_SOCKETS_PATH = "queue/sockets";
/// Serializes console/log-file output produced by the different threads of this tool.
std::mutex G_MUTEX;

/**
 * @brief Class used to create a server for the alerts messages.
 *
 */
class FakeReportServer
{
private:
    int m_socketServer;
    int m_socketErrno {0};
    std::thread m_fakeServerThread;
    std::atomic<bool> m_shouldStop {false};
    char m_buffer[MAXLEN] {0};
    size_t m_bytesReceived {0};
    std::string m_path;
    struct sockaddr_un m_clientAddr
    {
    };
    struct sockaddr_un m_serverAddr
    {
        .sun_family = AF_UNIX
    };
    socklen_t m_clientSize;

    /**
     * @brief Body of the server thread: prints every received datagram until stop() is requested.
     *
     * The stop flag is only checked after a recvfrom() returns, so a blocked thread must be woken
     * by sending a datagram to the socket.
     */
    void receiveLoop()
    {
        do
        {
            // The address length is a value-result argument of recvfrom(); reset it before every call.
            m_clientSize = sizeof(m_clientAddr);
            // Keep the signed result: on error recvfrom() returns -1, which an unsigned variable would
            // turn into a huge byte count.
            const auto received = recvfrom(m_socketServer,
                                           m_buffer,
                                           MAXLEN - 1,
                                           0,
                                           reinterpret_cast<struct sockaddr*>(&m_clientAddr),
                                           &m_clientSize);
            // Capture errno before any other call (locking, stream output) can overwrite it.
            const auto recvErrno = errno;

            std::lock_guard<std::mutex> lock(G_MUTEX);
            if (received >= 0)
            {
                // Zero-length datagrams are valid, so only a negative result is treated as an error.
                m_bytesReceived = static_cast<size_t>(received);
                if (m_bytesReceived > static_cast<size_t>(MAXLEN - 1))
                {
                    m_bytesReceived = MAXLEN - 1;
                }
                m_buffer[m_bytesReceived] = '\0';
                std::cout << "Fake report server message: " << std::string(m_buffer, m_bytesReceived) << std::endl;
            }
            else
            {
                std::cout << "Fake report server err message: " << strerror(recvErrno) << std::endl;
            }
        } while (!m_shouldStop.load());
    }

    /**
     * @brief Best-effort: sends a datagram to the server socket so a thread blocked in recvfrom() returns.
     *
     * Errors are ignored; this is only used while tearing the server down.
     */
    void wakeUp() const noexcept
    {
        const int client = socket(AF_UNIX, SOCK_DGRAM, 0);
        if (client < 0)
        {
            return;
        }
        constexpr char WAKE_MESSAGE[] = "stop";
        static_cast<void>(sendto(client,
                                 WAKE_MESSAGE,
                                 sizeof(WAKE_MESSAGE) - 1,
                                 0,
                                 reinterpret_cast<const struct sockaddr*>(&m_serverAddr),
                                 sizeof(m_serverAddr)));
        close(client);
    }

public:
    /**
     * @brief Construct a new Fake Report Server object.
     *
     * Creates the UNIX datagram socket; creation errors are reported by start().
     *
     * @param path Location of the server to create.
     */
    explicit FakeReportServer(std::string path)
    {
        m_socketServer = socket(AF_UNIX, SOCK_DGRAM, 0);
        // Remember why socket() failed: start() reports it later, when errno may have been overwritten.
        m_socketErrno = errno;
        m_path = std::move(path);
        m_clientSize = sizeof(m_clientAddr);
    }

    // Owns a socket descriptor and a thread that references this object: not copyable.
    FakeReportServer(const FakeReportServer&) = delete;
    FakeReportServer& operator=(const FakeReportServer&) = delete;

    /**
     * @brief Destroy the Fake Report Server object, stopping the thread and closing the socket.
     *
     * If the thread is still running (e.g. the owner left its scope through an exception), it is
     * asked to stop and woken up first, so the join below does not block forever in recvfrom().
     */
    ~FakeReportServer()
    {
        if (m_fakeServerThread.joinable())
        {
            stop();
            wakeUp();
        }
        waitForStop();
    }

    /**
     * @brief Binds the socket to the configured path and starts the receiving thread.
     *
     * Binding happens on the calling thread so failures propagate to the caller (an exception escaping
     * a std::thread entry point would call std::terminate), and so the socket is ready when this returns.
     *
     * @throws std::runtime_error If the socket could not be created, the path is too long, or bind() fails.
     */
    void start()
    {
        if (m_socketServer < 0)
        {
            throw std::runtime_error("Failed to create socket. Reason: " + std::string(strerror(m_socketErrno)));
        }

        if (m_path.size() >= sizeof(m_serverAddr.sun_path))
        {
            throw std::runtime_error("Failed to bind socket: path too long: " + m_path);
        }

        if (std::filesystem::exists(m_path))
        {
            std::filesystem::remove(m_path);
        }

        std::snprintf(m_serverAddr.sun_path, sizeof(m_serverAddr.sun_path), "%s", m_path.c_str());

        if (bind(m_socketServer, reinterpret_cast<struct sockaddr*>(&m_serverAddr), sizeof(m_serverAddr)) < 0)
        {
            throw std::runtime_error("Failed to bind socket: " + std::string(strerror(errno)));
        }

        m_fakeServerThread = std::thread([this]() { receiveLoop(); });
    }

    /**
     * @brief Sends the stop signal to the main thread.
     *
     * The thread only observes it after its current recvfrom() returns.
     */
    void stop()
    {
        m_shouldStop.store(true);
    }

    /**
     * @brief Waits for the thread to join, then closes the socket and removes the socket file.
     *
     * Does not throw: the socket file is removed with the error_code overload so this is safe to call
     * from the destructor.
     */
    void waitForStop()
    {
        if (m_fakeServerThread.joinable())
        {
            m_fakeServerThread.join();
        }

        constexpr auto INVALID_SOCKET {-1};
        if (m_socketServer != INVALID_SOCKET)
        {
            close(m_socketServer);
            m_socketServer = INVALID_SOCKET;
        }

        // remove() with an error_code never throws and is a no-op when the file does not exist.
        std::error_code removeError;
        std::filesystem::remove(m_path, removeError);
    }
};

int main(const int argc, const char* argv[])
{
    try
    {
        auto& routerModule = RouterModule::instance();
        auto& vulnerabilityScanner = VulnerabilityScanner::instance();
        CmdLineArgs cmdLineArgs(argc, argv);

        // Read json configuration file
        auto configuration = nlohmann::json::parse(std::ifstream(cmdLineArgs.getConfigurationFilePath()));

        // If the template file path is provided, set in the configuration adding the template path.
        // Otherwise, the default template will be used.
        if (!cmdLineArgs.getTemplateFilePath().empty())
        {
            configuration["indexer"]["template_path"] = cmdLineArgs.getTemplateFilePath();
        }

        routerModule.start();

        auto routerProviderDbSync = RouterProvider("deltas-syscollector", true);
        auto routerProviderRSync = RouterProvider("rsync-syscollector", true);
        auto routerProviderDbUpdate = RouterProvider("wdb-agent-events", true);
        routerProviderDbSync.start();
        routerProviderRSync.start();
        routerProviderDbUpdate.start();

        // Fake Wazuh-DB server
        auto fakeDBServer =
            std::make_shared<SocketServer<Socket<OSPrimitives, SizeHeaderProtocol>, EpollWrapper>>(DEFAULT_WDB_SOCKET);
        nlohmann::json fakeAgentData;
        nlohmann::json fakeGlobalData;
        nlohmann::json fakeAgentPackages;
        nlohmann::json fakeAgentHotfixes;
        if (!cmdLineArgs.getFakeDBServer().empty())
        {
            fakeAgentData = nlohmann::json::parse(std::ifstream(cmdLineArgs.getFakeDBServer()));
            fakeGlobalData = cmdLineArgs.getFakeDBGlobal().empty()
                                 ? nlohmann::json()
                                 : nlohmann::json::parse(std::ifstream(cmdLineArgs.getFakeDBGlobal()));
            fakeAgentPackages = cmdLineArgs.getFakeDBPkgs().empty()
                                    ? nlohmann::json()
                                    : nlohmann::json::parse(std::ifstream(cmdLineArgs.getFakeDBPkgs()));
            fakeAgentHotfixes = cmdLineArgs.getFakeDBHotfixes().empty()
                                    ? nlohmann::json()
                                    : nlohmann::json::parse(std::ifstream(cmdLineArgs.getFakeDBHotfixes()));

            fakeDBServer->listen(
                [&](const int fd, const char* data, uint32_t dataSize, const char*, uint32_t)
                {
                    auto messageReceived = std::string(data, dataSize);
                    auto tokens = Utils::split(messageReceived, ' ');
                    const auto& tableName = tokens[tokens.size() - 1];
                    std::string errMessage = "err Query not supported";
                    if (tokens.size() >= 2)
                    {
                        if (tokens[0] == "global")
                        {
                            std::string successMessage = "ok " + fakeGlobalData.dump();
                            fakeDBServer->send(fd, successMessage.c_str(), successMessage.size());
                        }
                        else if (tokens[0] == "agent" && Utils::isNumber(tokens[1]))
                        {
                            const auto& agentId = tokens[1].length() < 3 ? "00" + tokens[1] : tokens[1];
                            if (tableName.find("sys_programs") != std::string::npos || tokens[2] == "package")
                            {
                                std::string responseMessage = fakeAgentPackages.contains(agentId)
                                                                  ? "ok " + fakeAgentPackages[agentId].dump()
                                                                  : "ok []";
                                fakeDBServer->send(fd, responseMessage.c_str(), responseMessage.size());
                            }
                            else if (tableName.find("sys_osinfo") != std::string::npos || tokens[2] == "osinfo")
                            {
                                std::string responseMessage =
                                    fakeAgentData.contains(agentId) ? "ok " + fakeAgentData[agentId].dump() : "ok []";
                                fakeDBServer->send(fd, responseMessage.c_str(), responseMessage.size());
                            }
                            else if (tableName.find("sys_hotfixes") != std::string::npos || tokens[2] == "hotfix")
                            {
                                std::string responseMessage = fakeAgentHotfixes.contains(agentId)
                                                                  ? "ok " + fakeAgentHotfixes[agentId].dump()
                                                                  : "ok []";
                                std::cout << "Response message: " << responseMessage << std::endl;
                                fakeDBServer->send(fd, responseMessage.c_str(), responseMessage.size());
                            }
                            else
                            {
                                for (const auto& token : tokens)
                                {
                                    std::cout << "Token " << token << std::endl;
                                }
                                std::cout << "Invalid table name: " << tableName << std::endl;
                            }
                        }
                        else
                        {
                            fakeDBServer->send(fd, errMessage.c_str(), errMessage.size());
                        }
                    }
                    else
                    {
                        fakeDBServer->send(fd, errMessage.c_str(), errMessage.size());
                    }
                });
        }

        // Fake alerts server
        FakeReportServer fakeReportServer(DEFAULT_QUEUE_PATH);
        if (cmdLineArgs.getFakeReportServer())
        {
            if (!std::filesystem::exists(DEFAULT_SOCKETS_PATH))
            {
                std::filesystem::create_directories(DEFAULT_SOCKETS_PATH);
            }
            fakeReportServer.start();
        }

        // Open file to write log.
        std::ofstream logFile;
        if (!cmdLineArgs.getLogFilePath().empty())
        {
            logFile.open(cmdLineArgs.getLogFilePath());
            if (!logFile.is_open())
            {
                throw std::runtime_error("Failed to open log file: " + cmdLineArgs.getLogFilePath());
            }
        }

        vulnerabilityScanner.start(
            [&logFile](const int logLevel,
                       const std::string& tag,
                       const std::string& file,
                       const int line,
                       const std::string& func,
                       const std::string& message,
                       va_list args)
            {
                auto pos = file.find_last_of('/');
                if (pos != std::string::npos)
                {
                    pos++;
                }
                std::string fileName = file.substr(pos, file.size() - pos);
                char formattedStr[MAXLEN] = {0};
                vsnprintf(formattedStr, MAXLEN, message.c_str(), args);

                std::lock_guard<std::mutex> lock(G_MUTEX);
                if (logLevel != LOG_ERROR)
                {
                    std::cout << tag << ":" << fileName << ":" << line << " " << func << " : " << formattedStr
                              << std::endl;
                }
                else
                {
                    std::cerr << tag << ":" << fileName << ":" << line << " " << func << " : " << formattedStr
                              << std::endl;
                }

                if (logFile.is_open())
                {
                    logFile << tag << ":" << fileName << ":" << line << " " << func << " : " << formattedStr
                            << std::endl;
                }
                // Flush the log file every time a message is written.
                logFile.flush();
            },
            configuration,
            false,
            !cmdLineArgs.getOnlyDownloadContent(),
            !cmdLineArgs.getDisableContentUpdater());

        if (!cmdLineArgs.getOnlyDownloadContent())
        {
            // Wait for the complete initialization and connection negotiation.
            std::this_thread::sleep_for(std::chrono::seconds(1));

            for (const auto& inputFile : cmdLineArgs.getInputFiles())
            {
                std::cout << "Processing file: " << inputFile << std::endl;
                // Parse inputFile JSON.
                const auto jsonInputFile = nlohmann::json::parse(std::ifstream(inputFile)).dump();

                if (jsonInputFile.find("action") != std::string::npos)
                {
                    std::vector<char> json_vector(jsonInputFile.begin(), jsonInputFile.end());
                    routerProviderDbUpdate.send(json_vector);
                    continue;
                }
                else
                {
                    flatbuffers::Parser parser;
                    std::string jsonContent;
                    bool isDelta = true;
                    if (parser.Parse(syscollector_deltas_SCHEMA))
                    {
                        if (parser.Parse(jsonInputFile.c_str()))
                        {
                            std::cout << "Syscollector delta parsed successfully" << std::endl;
                        }
                        else
                        {
                            if (parser.Parse(syscollector_synchronization_SCHEMA))
                            {
                                if (parser.Parse(jsonInputFile.c_str()))
                                {
                                    isDelta = false;
                                    std::cout << "Syscollector synchronization parsed successfully" << std::endl;
                                }
                                else
                                {
                                    throw std::runtime_error("Failed to parse JSON input file. Reason: " +
                                                             parser.error_);
                                }
                            }
                        }
                    }

                    std::cout << "size: " << parser.builder_.GetSize() << std::endl;
                    // Convert to flatbuffer.
                    std::vector<char> buffer {parser.builder_.GetBufferPointer(),
                                              parser.builder_.GetBufferPointer() + parser.builder_.GetSize()};
                    isDelta ? routerProviderDbSync.send(buffer) : routerProviderRSync.send(buffer);
                }
                // Wait for the complete initialization and connection negotiation.
                std::this_thread::sleep_for(std::chrono::seconds(1));
            }

            if (cmdLineArgs.getWaitTime() > 0)
            {
                std::this_thread::sleep_for(std::chrono::seconds(cmdLineArgs.getWaitTime()));
            }
            else
            {
                std::cout << "Press enter to stop the scanner..." << std::endl;
                std::cin.get();
            }
        }
        else
        {
            // Wait for the start of snapshot processing.
            std::this_thread::sleep_for(std::chrono::seconds(60));
        }

        routerProviderDbSync.stop();
        routerProviderRSync.stop();
        routerProviderDbUpdate.stop();
        vulnerabilityScanner.stop();
        routerModule.stop();
        ContentModule::instance().stop();
        if (cmdLineArgs.getFakeReportServer())
        {
            fakeReportServer.stop();
            // The server has a blocking recv(), sending any message to the socket will unblock it.
            auto tempSocketClient =
                std::make_shared<SocketClient<Socket<OSPrimitives, NoHeaderProtocol>, EpollWrapper>>(
                    DEFAULT_QUEUE_PATH);
            tempSocketClient->connect(
                [](const char* data, uint32_t size, const char* dataHeader, uint32_t sizeHeader) {},
                []() {},
                SOCK_DGRAM);
            tempSocketClient->send("stop", 4);

            fakeReportServer.waitForStop();
            tempSocketClient->stop();
        }
        if (!cmdLineArgs.getFakeDBServer().empty())
        {
            fakeDBServer->stop();
        }
    }
    catch (const std::exception& e)
    {
        std::cerr << e.what() << std::endl;
        CmdLineArgs::showHelp();
        return 1;
    }
    return 0;
}
